import os
import sys

# Resolve this script's directory BEFORE changing the working directory:
# if __file__ is a relative path (the __main__ script before Python 3.9),
# os.path.abspath resolves it against the *current* cwd, so computing it
# after os.chdir would point at the wrong directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Move one level up so the relative RESULT_PATH below resolves from there.
# NOTE(review): "../" is relative to the cwd the script was launched from,
# not to the script's own location — presumably it is run from its own dir.
os.chdir("../")
sys.path.append(_SCRIPT_DIR)
import numpy as np
import matplotlib.pyplot as plt

# Alternative results directory: "results/thesis_submission_results/"
RESULT_PATH = "results/multi_booths_old/1/"

# Number of independent runs stored along axis 0 of each result array.
NUM_RUNS = 4
# Total training episodes per run.
NUM_EPISODES = 12000
# Episode spacing between logged points; each result curve is assumed to
# hold one value per logged episode (TODO confirm against the logger).
LOG_INTERVAL = 20
# x-axis positions: 0, LOG_INTERVAL, 2*LOG_INTERVAL, ... < NUM_EPISODES.
episodes_indices = list(range(0, NUM_EPISODES, LOG_INTERVAL))

def _load_mean_sem(filename):
    """Load one result array from RESULT_PATH and summarize it over axis 0.

    Axis 0 is treated as the run axis (the mean is taken across it).

    Args:
        filename: Name of a ``.npy`` file inside ``RESULT_PATH``.

    Returns:
        ``(runs, mean, sem)``: the raw loaded array, its mean over axis 0,
        and its standard deviation over axis 0 divided by the square root
        of the number of rows along axis 0 (standard error of the mean).
    """
    runs = np.load(os.path.join(RESULT_PATH, filename))
    # Use the run count actually stored in the file rather than the NUM_RUNS
    # constant, so the error band stays correct if the two ever disagree.
    num_runs = runs.shape[0]
    mean = np.mean(runs, axis=0)
    # NOTE(review): np.std defaults to ddof=0 (population std); a textbook
    # SEM uses ddof=1. Left at ddof=0 to keep the existing band width.
    sem = np.std(runs, axis=0) / np.sqrt(num_runs)
    return runs, mean, sem


# The *_std names hold the standard error of the mean, not the raw std.
# Each print reports the overall mean of that configuration's mean curve.
obl_mi, obl_mi_mean, obl_mi_std = _load_mean_sem(
    "obl_mi_reward_mi_log2_argmax.npy")
print(np.mean(obl_mi_mean))

obl_mi_loss, obl_mi_loss_mean, obl_mi_loss_std = _load_mean_sem(
    "obl_mi_reward_mi_loss_argmax.npy")
print(np.mean(obl_mi_loss_mean))

obl_mi_mi_loss, obl_mi_mi_loss_mean, obl_mi_mi_loss_std = _load_mean_sem(
    "obl_mi_reward_mi_log2_mi_loss_argmax.npy")
print(np.mean(obl_mi_mi_loss_mean))


# Hex RGB palette shared by all plotted curves.
CB91_Blue, CB91_Green, CB91_Pink = '#2CBDFE', '#47DBCD', '#F3A0F2'
CB91_Purple, CB91_Violet, CB91_Amber = '#9D2EC5', '#661D98', '#F5B14C'
seventh_color = "#4b97ec"

# Curves pick their color by position in this list, so the order matters.
color_list = [
    CB91_Blue,
    CB91_Pink,
    CB91_Green,
    CB91_Amber,
    CB91_Purple,
    CB91_Violet,
    seventh_color,
]

# Plot each configuration's mean curve with a +/- standard-error band.
# Colors are taken from color_list in order (first series -> color_list[0]).
_series = [
    (obl_mi_mean, obl_mi_std, "OBL + MI Reward"),
    (obl_mi_loss_mean, obl_mi_loss_std, "OBL + MI Loss"),
    (obl_mi_mi_loss_mean, obl_mi_mi_loss_std, "OBL + MI Reward + MI Loss"),
]
for _color, (_curve_mean, _curve_sem, _label) in zip(color_list, _series):
    # Squeeze the line AND the band the same way: fill_between requires 1-D
    # y-values, so squeezing only the plotted line would make the band call
    # fail whenever the arrays carry a singleton axis.
    _curve_mean = np.squeeze(_curve_mean)
    _curve_sem = np.squeeze(_curve_sem)
    plt.plot(episodes_indices, _curve_mean, label=_label, color=_color)
    plt.fill_between(
        episodes_indices,
        _curve_mean - _curve_sem,
        _curve_mean + _curve_sem,
        facecolor=_color,
        alpha=0.3,
    )

# Hide the top and right axis borders for a cleaner look.
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.legend()
plt.ylabel("Running MI Reward")
plt.xlabel("Episodes")
# NOTE(review): every curve is labelled "OBL + ..." but the title says IQL —
# possibly a stale title; confirm.
plt.title("Hyperparameters Sweep using IQL")
plt.show()
